package v2

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"time"

	"github.com/shopspring/decimal"

	"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
	"github.com/smartcontractkit/chainlink-common/pkg/settings/limits"
	"github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host"
	sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk"
	"github.com/smartcontractkit/chainlink-protos/cre/go/values"
	protoevents "github.com/smartcontractkit/chainlink-protos/workflows/go/events"
	"github.com/smartcontractkit/chainlink/v2/core/platform"
	"github.com/smartcontractkit/chainlink/v2/core/services/workflows/events"
	"github.com/smartcontractkit/chainlink/v2/core/services/workflows/metering"
	"github.com/smartcontractkit/chainlink/v2/core/services/workflows/store"
)

// Compile-time assertion that *ExecutionHelper satisfies host.ExecutionHelper.
var _ host.ExecutionHelper = (*ExecutionHelper)(nil)

// ExecutionHelper is the host-side helper handed to the wasm guest for a
// single workflow execution. It services capability calls and user log
// emission on behalf of that execution.
type ExecutionHelper struct {
	// Engine is embedded; the methods in this file read engine state such as
	// cfg, lggr, metrics, localNode, meterReports and capCallsSemaphore
	// through the helper (presumably promoted from Engine — declared elsewhere).
	*Engine
	// WorkflowExecutionID identifies the execution this helper serves. It is
	// returned by GetWorkflowExecutionID and attached to capability request
	// metadata, events, and the metering report lookup.
	WorkflowExecutionID string
	// UserLogChan is the send-only channel EmitUserLog writes log lines to.
	// Sends are non-blocking; lines are dropped when the channel cannot accept them.
	UserLogChan         chan<- *protoevents.LogLine
	// TimeProvider and SecretsFetcher are embedded; their definitions are not
	// in this file.
	TimeProvider
	SecretsFetcher

	// callLimiters maps a (capability name, method) pair to the limiter that
	// bounds how many such calls this execution may make. Set by initLimiters.
	callLimiters map[capCall]limits.BoundLimiter[int]
	// mu guards callCounts.
	mu           sync.Mutex
	// callCounts tracks, per limiter, how many limited calls have been admitted
	// so far. It is allocated lazily on the first limited call in CallCapability.
	callCounts   map[limits.Limiter[int]]int
}

// initLimiters installs the table that associates each limited
// (capability name, method) pair with the limiter from limiters that bounds it.
// Both consensus methods share the same ConsensusCalls limiter.
func (c *ExecutionHelper) initLimiters(limiters *EngineLimiters) {
	table := make(map[capCall]limits.BoundLimiter[int], 5)

	// Consensus: both methods are bounded by the same limiter.
	table[capCall{name: "consensus", method: "Simple"}] = limiters.ConsensusCalls
	table[capCall{name: "consensus", method: "Report"}] = limiters.ConsensusCalls

	// EVM reads and writes.
	table[capCall{name: "evm", method: "FilterLogs"}] = limiters.ChainReadCalls
	table[capCall{name: "evm", method: "WriteReport"}] = limiters.ChainWriteTargets

	// Outbound HTTP actions.
	table[capCall{name: "http-actions", method: "SendRequest"}] = limiters.HTTPActionCalls

	c.callLimiters = table
}

// capCall is the key used to look up a call limiter: a capability name paired
// with the method being invoked on it.
type capCall struct {
	// name is the capability name, as returned first by capabilities.ParseID
	// for the request ID.
	name   string
	// method is the capability method from the request.
	method string
}

// CallCapability services a capability request issued by the wasm guest.
//
// If a limiter is registered for the request's (capability name, method) pair,
// the per-limiter call count for this helper is checked against it and, when
// the check passes, incremented. The call then waits for a slot on the
// capability-calls semaphore before delegating to callCapability; the slot is
// released when the call returns. Any limiter or semaphore error is returned
// unchanged.
func (c *ExecutionHelper) CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) {
	name, _, _ := capabilities.ParseID(request.Id)
	key := capCall{name: name, method: request.Method}

	if bound, limited := c.callLimiters[key]; limited {
		// admit checks the next count against the limiter under the lock and
		// records it only when the limiter accepts it.
		admit := func() error {
			c.mu.Lock()
			defer c.mu.Unlock()

			if c.callCounts == nil {
				c.callCounts = make(map[limits.Limiter[int]]int)
			}
			next := c.callCounts[bound] + 1
			if checkErr := bound.Check(ctx, next); checkErr != nil {
				return checkErr
			}
			c.callCounts[bound] = next
			return nil
		}
		if admitErr := admit(); admitErr != nil {
			return nil, admitErr
		}
	}

	release, err := c.capCallsSemaphore.Wait(ctx, 1)
	if err != nil {
		return nil, err
	}
	defer release()

	return c.callCapability(ctx, request)
}

// callCapability performs one capability invocation for the current workflow
// execution and returns the capability's response payload.
//
// Steps, in order:
//  1. Resolve the executable capability for request.Id from the registry and
//     read its info.
//  2. Pick the DON ID used for the config lookup: the capability's own DON for
//     remote capabilities, the local node's workflow DON for local ones.
//  3. If a metering report exists for this execution, call Deduct on it to
//     obtain the spend limits attached to the request.
//  4. Execute the capability under the CapabilityCallTime limit, recording
//     metrics and emitting started/finished events.
//  5. On success, call Settle on the metering report with the response metadata.
//
// An error is returned when the capability or its info cannot be resolved, a
// remote capability has no DON, the concurrency availability cannot be read,
// the timeout context cannot be created, or execution fails. Failures of the
// config lookup, Deduct and Settle are logged and do not abort the call.
func (c *ExecutionHelper) callCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) {
	// TODO (CAPPL-735): use request.Metadata.WorkflowExecutionId to associate the call with a specific execution
	capability, err := c.cfg.CapRegistry.GetExecutable(ctx, request.Id)
	if err != nil {
		return nil, fmt.Errorf("action capability not found: %w", err)
	}

	info, err := capability.Info(ctx)
	if err != nil {
		return nil, fmt.Errorf("capability info not found: %w", err)
	}

	localNode := c.localNode.Load()

	// A local capability uses the local node's workflow DON ID. A remote
	// capability must carry its own DON; a missing one is an error.
	var donID uint32
	if !info.IsLocal {
		if info.DON == nil {
			return nil, fmt.Errorf("remote capability info is missing DON field, ID: %s", info.ID)
		}
		donID = info.DON.ID
	} else {
		donID = localNode.WorkflowDON.ID
	}

	config, err := c.cfg.CapRegistry.ConfigForCapability(ctx, info.ID, donID)
	if err != nil {
		// not explicitly an error case and more relevant (helpful) logging occurs in the metering package
		// debug level should be sufficient here
		c.lggr.Debugf("capability config not found: %s", err)
	}

	meterReport, ok := c.meterReports.Get(c.WorkflowExecutionID)
	if !ok {
		c.lggr.Errorf("no metering report found for %v", c.WorkflowExecutionID)
	}

	// meteringRef doubles as the request's reference ID and as the key under
	// which the metering report tracks this call.
	meteringRef := strconv.Itoa(int(request.CallbackId))
	spendLimits := []capabilities.SpendLimit{}

	if meterReport != nil {
		// TODO: https://smartcontract-it.atlassian.net/browse/CRE-285 get max spend per step from SDK.
		// TODO: https://smartcontract-it.atlassian.net/browse/CRE-284 parse user max spend for step
		// Until then the user spend limit is passed as an invalid (unset) null decimal.
		userSpendLimit := decimal.NewNullDecimal(decimal.Zero)
		userSpendLimit.Valid = false

		var openConcurrentCallSlots int
		if openConcurrentCallSlots, err = c.cfg.LocalLimiters.CapabilityConcurrency.Available(ctx); err != nil {
			return nil, err
		}

		// A failed deduction is logged but does not block the call; spendLimits
		// then holds whatever Deduct returned alongside the error.
		if spendLimits, err = meterReport.Deduct(
			meteringRef,
			metering.ByDerivedAvailability(
				userSpendLimit,
				openConcurrentCallSlots,
				info,
				config.RestrictedConfig,
			),
		); err != nil {
			c.lggr.Errorw("could not deduct balance for capability request", "capReq", request.Id, "capReqCallbackID", request.CallbackId, "err", err)
		}
	}

	capReq := capabilities.CapabilityRequest{
		Payload:      request.Payload,
		Method:       request.Method,
		CapabilityId: request.Id,
		Metadata: capabilities.RequestMetadata{
			WorkflowID:               c.cfg.WorkflowID,
			WorkflowOwner:            c.cfg.WorkflowOwner,
			WorkflowExecutionID:      c.WorkflowExecutionID,
			WorkflowName:             c.cfg.WorkflowName.Hex(),
			WorkflowDonID:            localNode.WorkflowDON.ID,
			WorkflowDonConfigVersion: localNode.WorkflowDON.ConfigVersion,
			ReferenceID:              meteringRef,
			DecodedWorkflowName:      c.cfg.WorkflowName.String(),
			SpendLimits:              spendLimits,
			WorkflowTag:              c.cfg.WorkflowTag,
		},
		Config: values.EmptyMap(),
	}

	c.lggr.Debugw("Executing capability ...", "capID", request.Id, "capReqCallbackID", request.CallbackId, "capReqMethod", request.Method)
	c.metrics.With(platform.KeyCapabilityID, request.Id).IncrementCapabilityInvocationCounter(ctx)
	// Event emission is best-effort: a failure to emit must not fail the call.
	_ = events.EmitCapabilityStartedEvent(ctx, c.loggerLabels, c.WorkflowExecutionID, request.Id, meteringRef, request.Method)

	execCtx, execCancel, err := c.cfg.LocalLimiters.CapabilityCallTime.WithTimeout(ctx)
	if err != nil {
		return nil, err
	}
	defer execCancel()

	executionStart := time.Now()
	capResp, err := capability.Execute(execCtx, capReq)
	executionDuration := time.Since(executionStart)
	// NOTE(review): the duration is truncated to whole seconds, so sub-second
	// executions are recorded as 0 — confirm this matches the histogram's unit.
	c.metrics.With(platform.KeyCapabilityID, request.Id).UpdateCapabilityExecutionDurationHistogram(ctx, int64(executionDuration.Seconds()))
	if err != nil {
		c.lggr.Debugw("Capability execution failed", "capID", request.Id, "capReqCallbackID", request.CallbackId, "err", err)
		_ = events.EmitCapabilityFinishedEvent(ctx, c.loggerLabels, c.WorkflowExecutionID, request.Id, meteringRef, store.StatusErrored, request.Method, err)
		c.metrics.With(platform.KeyCapabilityID, request.Id).IncrementCapabilityFailureCounter(ctx)
		c.metrics.IncrementTotalWorkflowStepErrorsCounter(ctx)
		return nil, fmt.Errorf("failed to execute capability: %w", err)
	}

	c.lggr.Debugw("Capability execution succeeded", "capID", request.Id, "capReqCallbackID", request.CallbackId)
	_ = events.EmitCapabilityFinishedEvent(ctx, c.loggerLabels, c.WorkflowExecutionID, request.Id, meteringRef, store.StatusCompleted, request.Method, nil)

	if meterReport != nil {
		// A settlement failure is logged only; the guest still receives the response.
		if err = meterReport.Settle(meteringRef, capResp.Metadata); err != nil {
			c.lggr.Errorw("failed to set metering for capability request", "capReq", request.Id, "capReqCallbackID", request.CallbackId, "err", err)
		}
	}

	return &sdkpb.CapabilityResponse{
		Response: &sdkpb.CapabilityResponse_Payload{
			Payload: capResp.Payload,
		},
	}, nil
}

// GetWorkflowExecutionID returns the ID of the workflow execution this helper
// was created for.
func (c *ExecutionHelper) GetWorkflowExecutionID() string {
	return c.WorkflowExecutionID
}

// EmitUserLog queues msg, stamped with the node's current time in RFC 3339
// nanosecond format, on UserLogChan. The send never blocks: if the channel
// cannot accept the line right now, the line is dropped and a warning is
// logged. The returned error is always nil.
func (c *ExecutionHelper) EmitUserLog(msg string) error {
	line := &protoevents.LogLine{
		NodeTimestamp: time.Now().Format(time.RFC3339Nano),
		Message:       msg,
	}

	select {
	case c.UserLogChan <- line:
		// Delivered to the channel.
		return nil
	default:
	}

	c.lggr.Warnw("Exceeded max allowed user log messages, dropping")
	return nil
}
